import numpy as np
from numba import vectorize, njit

from toolkit.utils import interpolate_y
from toolkit.het_block import het


@het(exogenous='Pi', policy='a', backward='uc')
def household(uc_p, Pi_p, a_grid, e_grid, beta_grid, T, tax, rpost, rsub, sigma, alpha, nu, vphi, c_const, n_const,
              w, ssflag=False):
    """Single backward iteration step using endogenous gridpoint method for households with GHH+ utility.

    Marginal utility of consumption is uc = (c - vphi*alpha*n**(1+nu)/(1+nu)) ** (-sigma),
    and the budget constraint is c + a = (1 + rpost) * a_{-1} + w*(1 - tax)*e*n + T.

    T has to have the same dimension as e_grid.
    tax has to have the same dimension as e_grid.

    Parameters
    ----------
    uc_p : next-period marginal utility of consumption; rows are indexed by e and
        columns by a (assumed shape (len(e_grid), len(a_grid)) -- TODO confirm).
    Pi_p : transition matrix for e, left-multiplied onto uc_p
        (presumably Pi_p[i, j] = P(e_j next | e_i) -- confirm).
    a_grid, e_grid : asset grid and productivity grid.
    beta_grid : discount factor(s), multiplied elementwise with Pi_p (shape must
        broadcast against Pi_p -- confirm with caller).
    T, tax : lump-sum transfer and labor tax rate, one value per e.
    rpost : return applied to assets brought into the period, (1 + rpost) * a_grid.
    rsub : return used in the Euler equation.
    sigma, alpha, nu, vphi : preference parameters.
    c_const, n_const : precomputed consumption and labor at constrained points.
        They are used directly when ssflag is True. Otherwise c_const only seeds
        the Newton solve.
    w : wage.
    ssflag : if True, use c_const / n_const at constrained points instead of
        re-solving the static problem.

    Returns
    -------
    uc, a, c, n, ns, rev, tns : arrays indexed by (e, a); see inline comments.
    """
    # after-tax effective wage for each e; this one is useful to do internally
    ws = w * (1 - tax) * e_grid

    # uc(z_t, a_t): right-hand side of the Euler equation,
    # (1 + rsub) * beta * E[uc'], as a function of end-of-period assets a_t
    uc_nextgrid = ((1 + rsub) * beta_grid * Pi_p) @ uc_p

    # c(z_t, a_t) and n(z_t, a_t): invert marginal utility via `cn`
    c_nextgrid, n_nextgrid = cn(uc_nextgrid, ws[:, np.newaxis], sigma, alpha, nu, vphi)

    # c(z_t, a_{t-1}) and n(z_t, a_{t-1})
    # lhs is the implied (1 + rpost) * a_{t-1} from the budget constraint; the policies are
    # then mapped onto the fixed grid rhs (interpolate_y presumably interpolates its third
    # argument from nodes lhs to points rhs -- confirm against toolkit.utils)
    lhs = c_nextgrid - ws[:, np.newaxis] * n_nextgrid + a_grid[np.newaxis, :] - T[:, np.newaxis]
    rhs = (1 + rpost) * a_grid
    n = interpolate_y(lhs, rhs, n_nextgrid)
    c = interpolate_y(lhs, rhs, c_nextgrid)

    # test constraints, replace if needed: savings implied by the budget constraint,
    # clipped at the borrowing limit a_grid[0]
    a = rhs + ws[:, np.newaxis] * n + T[:, np.newaxis] - c
    iconst = np.nonzero(a < a_grid[0])
    a[iconst] = a_grid[0]

    if ssflag:
        # use precomputed values
        c[iconst] = c_const[iconst]
        n[iconst] = n_const[iconst]
    else:
        # have to solve again if in transition: Newton solve of c = ws*n + cash at the limit,
        # seeded with uc evaluated at c_const and the alpha=1 (pure GHH) labor supply at current ws.
        # NOTE(review): c_const - vphi*alpha*n_ghh**(1+nu)/(1+nu) is not guaranteed positive;
        # for non-integer sigma a negative base gives NaN and solve_uc would then fail -- confirm.
        n_ghh = np.ones((e_grid.shape[0], a_grid.shape[0])) * (ws[:, np.newaxis] / vphi) ** (1 / nu)
        uc_seed = (c_const[iconst] - vphi * alpha * n_ghh[iconst] ** (1 + nu) / (1 + nu)) ** (-sigma)
        # second argument is the non-labor resources: (1 + rpost) * a_{t-1} + T - a_grid[0]
        c[iconst], n[iconst] = solve_cn(ws[iconst[0]],
                                        rhs[iconst[1]] + T[iconst[0]] - a_grid[0], sigma, alpha, nu, vphi, uc_seed)

    # calculate marginal utility to go backward
    uc = (c - vphi * alpha * n ** (1 + nu) / (1 + nu)) ** (-sigma)

    # efficiency units of labor (for production)
    ns = e_grid[:, np.newaxis] * n

    # tax revenues (for gov budget)
    rev = (w * tax[:, np.newaxis]) * ns

    # after-tax income (for MPE calculations)
    tns = ws[:, np.newaxis] * n

    # NOTE(review): keep this return line as-is; the @het decorator may read output names
    # from it -- confirm in toolkit.het_block
    return uc, a, c, n, ns, rev, tns


@njit
def cn(uc, w, sigma, alpha, nu, vphi):
    """Map marginal utility of consumption ``uc`` into optimal consumption and labor.

    Labor solves the intratemporal condition
    ``vphi * n**nu * (alpha*uc + 1 - alpha) = w * uc``. Consumption is then
    recovered by inverting ``uc = (c - vphi*alpha*n**(1+nu)/(1+nu)) ** (-sigma)``.

    Returns
    -------
    c, n : values (scalars or arrays) broadcast from ``uc`` and ``w``.
    """
    labor_cost = vphi * alpha * uc + vphi * (1 - alpha)
    n = (w * uc / labor_cost) ** (1 / nu)
    disutility = vphi * alpha * n ** (1 + nu) / (1 + nu)
    c = uc ** (-1 / sigma) + disutility
    return c, n


@njit
def netexp(log_uc, w, T, sigma, alpha, nu, vphi):
    """Return net expenditure ``c - w*n - T`` and its derivative, both in terms of log(uc).

    ``solve_uc`` uses these as the residual and slope of its Newton iteration.
    The derivatives are d(.)/d log(uc) of the policies produced by ``cn``.
    """
    uc = np.exp(log_uc)
    c, n = cn(uc, w, sigma, alpha, nu, vphi)
    gap = c - w * n - T

    # d n / d log(uc) and d c / d log(uc), differentiating the closed forms in `cn`
    dn = n / nu * (1 - alpha) / (1 - alpha + alpha * uc)
    dc = vphi * alpha * n ** nu * dn - uc ** (-1/sigma) / sigma
    slope = dc - w * dn
    return gap, slope


@vectorize
def solve_uc(w, T, sigma, alpha, nu, vphi, uc_seed):
    """Solve the static labor/consumption problem for marginal utility of consumption uc.

    The problem solved is

        max_{c, n} (c - vphi*alpha*n**(1+nu)/(1+nu)) ** (1-sigma) / (1-sigma)
                   - vphi*(1-alpha)*n**(1+nu)/(1+nu)
        s.t. c = w*n + T

    Given uc, the optimal (c, n) follows from ``cn``. This function finds the uc
    at which the budget binds. It runs Newton's method on the residual
    ``c(uc) - w*n(uc) - T`` (from ``netexp``) in log(uc) space, starting from
    log(uc_seed), and stops once |residual| < 1e-11.

    Raises
    ------
    ValueError
        If the residual is not below tolerance within 100 Newton steps.
    """
    log_uc = np.log(uc_seed)
    for _ in range(100):
        ne, ne_p = netexp(log_uc, w, T, sigma, alpha, nu, vphi)
        if abs(ne) < 1E-11:
            return np.exp(log_uc)
        # Newton step on the budget residual
        log_uc -= ne / ne_p
    raise ValueError("Cannot solve constrained household's problem: No convergence after 100 iterations!")


def solve_cn(w, T, sigma, alpha, nu, vphi, uc_seed):
    """Return optimal (c, n) for the static problem with budget ``c = w*n + T``.

    Marginal utility is found by ``solve_uc`` (Newton iteration starting at
    ``uc_seed``) and then mapped to consumption and labor by ``cn``.
    """
    return cn(solve_uc(w, T, sigma, alpha, nu, vphi, uc_seed),
              w, sigma, alpha, nu, vphi)


'''Part 2: MPCs, MPEs, EIS, CI.'''


@njit
def elasticities(uc, w, sigma, alpha, nu, vphi):
    """Return ``(eis, ci)`` evaluated at marginal utility ``uc``.

    ``eis = -(dc / d log uc) / c``, i.e. minus the elasticity of the consumption
    policy from ``cn`` with respect to uc. ``ci = alpha / (alpha + (1 - alpha) / uc)``.
    """
    c, n = cn(uc, w, sigma, alpha, nu, vphi)
    dn_dloguc = n / nu * (1 - alpha) / (1 - alpha + alpha * uc)
    dc_dloguc = vphi * alpha * n ** nu * dn_dloguc - uc ** (-1/sigma) / sigma
    ci = alpha / (alpha + (1 - alpha) / uc)
    return -dc_dloguc / c, ci


def mpcs_mpes(c, a, tns, a_grid, r, eis, ci, nu):
    """Approximate MPCs and MPEs on the (e, a) grid.

    Both are slopes with respect to cash on hand from assets, ``(1 + r) * a_grid``.
    The MPC is the slope of ``c``. The MPE is minus the slope of ``tns``
    (presumably after-tax labor income from ``household`` -- confirm).
    Centred differences are used at interior asset points, a forward difference
    at the first point and a backward difference at the last.

    Where the borrowing constraint binds (``a == a_grid[0]``) both are
    overwritten by the closed form ``mpc = 1 / (1 + ratio)``, ``mpe = 1 - mpc``
    with ``ratio = tns / c / nu / eis * (1 - ci)``, so that mpc + mpe = 1 there.

    Returns
    -------
    mpcs, mpes : arrays with the shape and dtype of ``c``.
    """
    cash = (1 + r) * a_grid
    mpcs = np.empty_like(c)
    mpes = np.empty_like(c)

    for out, y, negate in ((mpcs, c, False), (mpes, tns, True)):
        # (output columns, change in y, change in cash) for each difference stencil
        stencils = (
            (slice(1, -1), y[:, 2:] - y[:, :-2], cash[2:] - cash[:-2]),  # centred
            (0, y[:, 1] - y[:, 0], cash[1] - cash[0]),                   # forward
            (-1, y[:, -1] - y[:, -2], cash[-1] - cash[-2]),              # backward
        )
        for cols, dy, dx in stencils:
            out[:, cols] = (-dy if negate else dy) / dx

    # constrained points: closed-form split that satisfies mpc + mpe = 1
    bound = np.where(a == a_grid[0])
    ratio = tns[bound] / c[bound] / nu / eis[bound] * (1 - ci[bound])
    mpc_bound = 1 / (1 + ratio)
    mpcs[bound] = mpc_bound
    mpes[bound] = 1 - mpc_bound

    return mpcs, mpes
